{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "cb02b541",
   "metadata": {},
   "source": [
    "# MLP简单版"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "ea74160e",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import torch.nn as nn\n",
    "import wandb # 参数可视化\n",
    "import pandas as pd\n",
    "from torch.nn import functional as F\n",
    "from tqdm import tqdm # 进度条\n",
    "import numpy as np\n",
    "from d2l import torch as d2l\n",
    "from torch.utils import data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "a103d6a3",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Save a checkpoint every NUM_SAVE epochs (read by the training loop).\n",
    "NUM_SAVE = 50\n",
    "# Human-readable layer layout that is logged to wandb. It mirrors the MLP class,\n",
    "# which has three hidden layers (256, 64, 16) before the scalar output.\n",
    "net_list = 'in->256->64->16->1'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "bdf69760",
   "metadata": {},
   "outputs": [],
   "source": [
    "class MLP(nn.Module):\n",
    "    \"\"\"Funnel-shaped multilayer perceptron mapping a feature vector to one scalar.\n",
    "\n",
    "    Args:\n",
    "        in_features: width of the input fed to the first linear layer.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, in_features):\n",
    "        super().__init__()\n",
    "        # Layer widths shrink step by step: in_features -> 256 -> 64 -> 16 -> 1.\n",
    "        self.layer1 = nn.Linear(in_features, 256)\n",
    "        self.layer2 = nn.Linear(256, 64)\n",
    "        self.layer3 = nn.Linear(64, 16)\n",
    "        self.out = nn.Linear(16, 1)\n",
    "\n",
    "    def forward(self, X):\n",
    "        \"\"\"Apply ReLU after each hidden layer; the output layer stays linear.\"\"\"\n",
    "        hidden = X\n",
    "        for hidden_layer in (self.layer1, self.layer2, self.layer3):\n",
    "            hidden = F.relu(hidden_layer(hidden))\n",
    "        return self.out(hidden)\n",
    "    \n",
    "# Pick the compute device: Apple-silicon MPS first, then CUDA, otherwise the CPU.\n",
    "# getattr keeps this working on PyTorch builds without a torch.backends.mps module.\n",
    "_mps_backend = getattr(torch.backends, 'mps', None)\n",
    "if _mps_backend is not None and _mps_backend.is_available():\n",
    "    device = torch.device('mps')\n",
    "elif torch.cuda.is_available():\n",
    "    device = torch.device('cuda')\n",
    "else:\n",
    "    device = torch.device('cpu')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "6fd2abb2",
   "metadata": {},
   "outputs": [],
   "source": [
    "def log_rmse(net, features, labels):\n",
    "    \"\"\"Return the RMSE between log-predictions and log-labels as a Python float.\n",
    "\n",
    "    Args:\n",
    "        net: callable model; ``net(features)`` is compared element-wise with ``labels``.\n",
    "        features: input tensor, on the same device as the model's parameters.\n",
    "        labels: target tensor; its logarithm is taken, so values are expected to be positive.\n",
    "\n",
    "    Returns:\n",
    "        float: sqrt(mean((log(pred) - log(label)) ** 2)).\n",
    "    \"\"\"\n",
    "    # Pure evaluation: skip autograd bookkeeping so a full-dataset forward pass\n",
    "    # does not keep every intermediate activation alive.\n",
    "    with torch.no_grad():\n",
    "        # Clamp predictions to at least 1 so their logarithm is finite and >= 0.\n",
    "        clipped_preds = torch.clamp(net(features), min=1.0)\n",
    "        # Always score with mean squared error, independent of whichever training\n",
    "        # criterion the notebook defines globally.\n",
    "        rmse = torch.sqrt(F.mse_loss(torch.log(clipped_preds), torch.log(labels)))\n",
    "    return rmse.item()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "c5f9c19f",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "((47439, 41), (31626, 40))"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Read the training and test tables from the shared data directory.\n",
    "train_data = pd.read_csv('../data/train.csv')\n",
    "test_data = pd.read_csv('../data/test.csv')\n",
    "# Last expression of the cell: (rows, columns) of both tables.\n",
    "(train_data.shape, test_data.shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "c88ae846",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Remove columns treated as redundant before building the feature matrix.\n",
    "redundant_data = ['Address', 'Summary', 'City', 'State', 'Zip']\n",
    "# drop(..., errors='ignore') keeps the cell re-runnable (a second `del` raised\n",
    "# KeyError) and needs no loop variable named `data`, which shadowed the\n",
    "# `torch.utils.data` module imported at the top of the notebook.\n",
    "train_data = train_data.drop(columns=redundant_data, errors='ignore')\n",
    "test_data = test_data.drop(columns=redundant_data, errors='ignore')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "2311b511",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Pre-processing: replace the large-valued columns by log(1 + x) in both tables.\n",
    "# 'Sold Price' is not in this list; the test table has no such column.\n",
    "large_vel_cols = ['Lot', 'Total interior livable area',\n",
    "                  'Tax assessed value', 'Annual tax amount',\n",
    "                  'Listed Price', 'Last Sold Price']\n",
    "\n",
    "# np.log1p(x) computes log(1 + x). The loop variable is `col` because `data`\n",
    "# would shadow the `torch.utils.data` module imported at the top of the notebook.\n",
    "# NOTE: the columns are overwritten in place, so run this cell once per data load.\n",
    "for col in large_vel_cols:\n",
    "    train_data[col] = np.log1p(train_data[col])\n",
    "    test_data[col] = np.log1p(test_data[col])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "c1823ff5",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "   Id  Sold Price          Type  Year built  Listed Price Last Sold On  \\\n",
      "0   0   3825000.0  SingleFamily      1969.0     15.250119          NaN   \n",
      "1   1    505000.0  SingleFamily      1926.0     13.171155   2019-08-30   \n",
      "2   2    140000.0  SingleFamily      1958.0     12.100718          NaN   \n",
      "3   3   1775000.0  SingleFamily      1947.0     14.454730   2016-08-30   \n",
      "\n",
      "   Last Sold Price  \n",
      "0              NaN  \n",
      "1        12.700772  \n",
      "2              NaN  \n",
      "3        14.220976  \n"
     ]
    }
   ],
   "source": [
    "# Show the first four training rows, restricted to the first four and last three columns.\n",
    "preview_cols = [0, 1, 2, 3, -3, -2, -1]\n",
    "print(train_data.iloc[:4, preview_cols])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "bdbcc118",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "      Id          Type  Year built      Heating  Listed Price Last Sold On  \\\n",
      "0  47439  SingleFamily      2020.0      Central     13.592243   2020-07-01   \n",
      "1  47440  SingleFamily      1924.0  Natural Gas     13.081439   2020-11-03   \n",
      "2  47441  SingleFamily      2020.0      Central     13.641039          NaN   \n",
      "3  47442  SingleFamily      2020.0      Central     13.604667   2020-09-21   \n",
      "\n",
      "   Last Sold Price  \n",
      "0        13.615841  \n",
      "1         9.615872  \n",
      "2              NaN  \n",
      "3        13.604791  \n"
     ]
    }
   ],
   "source": [
    "# Show the first four test rows, restricted to the first four and last three columns.\n",
    "preview_cols = [0, 1, 2, 3, -3, -2, -1]\n",
    "print(test_data.iloc[:4, preview_cols])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "935a95ce",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(79065, 34)"
      ]
     },
     "execution_count": 10,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Stack train and test features into one frame: skip the Id column in both tables\n",
    "# and additionally the label column (position 1) in the training table.\n",
    "train_part = train_data.iloc[:, 2:]\n",
    "test_part = test_data.iloc[:, 1:]\n",
    "all_features = pd.concat([train_part, test_part])\n",
    "\n",
    "# Parse the date columns into datetime values.\n",
    "for date_col in ('Listed On', 'Last Sold On'):\n",
    "    all_features[date_col] = pd.to_datetime(all_features[date_col], format='%Y-%m-%d')\n",
    "\n",
    "all_features.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "66a3fd63",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Type                 174\n",
      "Heating              2660\n",
      "Cooling              911\n",
      "Parking              9913\n",
      "Bedrooms             278\n",
      "Region               1259\n",
      "Elementary School    3568\n",
      "Middle School        809\n",
      "High School          922\n",
      "Flooring             1740\n",
      "Heating features     1763\n",
      "Cooling features     596\n",
      "Appliances included  11290\n",
      "Laundry features     3031\n",
      "Parking features     9695\n"
     ]
    }
   ],
   "source": [
    "# For every object-dtype (non-numeric) column, print how many distinct values it holds.\n",
    "object_cols = all_features.dtypes[all_features.dtypes == 'object'].index\n",
    "for col_name in object_cols:\n",
    "    distinct_values = all_features[col_name].unique()\n",
    "    print(col_name.ljust(20), len(distinct_values))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "65a97ab0",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Numeric columns are the float64 ones; remember their labels before filling.\n",
    "numeric_features = all_features.dtypes[all_features.dtypes == 'float64'].index\n",
    "# Fill each gap from the next row (back-fill); whatever is still missing becomes 0.\n",
    "# DataFrame.bfill() replaces fillna(method='bfill'), which pandas >= 2.1 deprecates.\n",
    "all_features = all_features.bfill(axis=0).fillna(0)\n",
    "# Standardize: subtract the column mean and divide by the column standard deviation.\n",
    "all_features[numeric_features] = all_features[numeric_features].apply(\n",
    "    lambda x: (x - x.mean()) / x.std())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "4371b7d5",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Model inputs: every numeric column plus the categorical 'Type' and 'Bedrooms'.\n",
    "features = [*numeric_features, 'Type', 'Bedrooms']\n",
    "# Restrict the frame to exactly those columns, in that order.\n",
    "all_features = all_features[features]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "id": "100efa0e",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "before one hot code (79065, 19)\n",
      "after one hot code (79065, 470)\n"
     ]
    }
   ],
   "source": [
    "# One-hot encode the categorical columns. dummy_na=True adds one extra indicator\n",
    "# column per categorical feature that flags missing values (NaN).\n",
    "print('before one hot code', all_features.shape)\n",
    "# dtype=float: pandas >= 2.0 emits bool dummies by default, which would turn\n",
    "# `all_features.values` into an object array that torch.tensor cannot convert.\n",
    "all_features = pd.get_dummies(all_features, dummy_na=True, dtype=float)\n",
    "print('after one hot code', all_features.shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "id": "ba8c7b48",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "train feature shape: torch.Size([47439, 470])\n",
      "test feature shape: torch.Size([31626, 470])\n",
      "train label shape: torch.Size([47439, 1])\n"
     ]
    }
   ],
   "source": [
    "# The first n_train rows of all_features are taken as the training part,\n",
    "# the remaining rows as the test part.\n",
    "n_train = train_data.shape[0]\n",
    "# to_numpy(dtype=np.float32) converts explicitly, so a frame mixing float, bool or\n",
    "# uint8 columns never reaches torch.tensor as an object array.\n",
    "train_features = torch.tensor(all_features.iloc[:n_train].to_numpy(dtype=np.float32))\n",
    "print('train feature shape:', train_features.shape)\n",
    "test_features = torch.tensor(all_features.iloc[n_train:].to_numpy(dtype=np.float32))\n",
    "print('test feature shape:', test_features.shape)\n",
    "# Labels become a column vector of shape (n_train, 1).\n",
    "train_labels = torch.tensor(\n",
    "    train_data['Sold Price'].to_numpy(dtype=np.float32).reshape(-1, 1))\n",
    "print('train label shape:', train_labels.shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "id": "53950db0",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Seed PyTorch's global RNG so the initial weights are the same after every\n",
    "# kernel restart.\n",
    "SEED = 42\n",
    "torch.manual_seed(SEED)\n",
    "# Training loss: mean squared error between predictions and labels.\n",
    "criterion = nn.MSELoss()\n",
    "# Input width equals the number of feature columns.\n",
    "in_features = train_features.shape[1]\n",
    "net = MLP(in_features).to(device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "id": "41c5744f",
   "metadata": {},
   "outputs": [],
   "source": [
    "def train(net, train_features, train_labels, test_features, test_labels,\n",
    "          num_epochs, learning_rate, weight_decay, batch_size):\n",
    "    \"\"\"Train ``net`` with Adam and log the per-epoch log-RMSE to wandb.\n",
    "\n",
    "    Args:\n",
    "        net: model to optimise; it is on ``device`` when the function returns.\n",
    "        train_features, train_labels: CPU tensors with the training set.\n",
    "        test_features, test_labels: optional CPU validation tensors; pass\n",
    "            ``None`` for both to skip validation.\n",
    "        num_epochs: number of passes over the training set.\n",
    "        learning_rate, weight_decay: Adam hyper-parameters.\n",
    "        batch_size: mini-batch size handed to ``d2l.load_array``.\n",
    "\n",
    "    Returns:\n",
    "        (train_ls, test_ls): per-epoch log-RMSE lists; ``test_ls`` stays empty\n",
    "        when ``test_labels`` is None.\n",
    "\n",
    "    Side effects: writes ``checkpoint_<epoch>`` files to the working directory\n",
    "    and closes the active wandb run with ``wandb.finish()``.\n",
    "    \"\"\"\n",
    "    wandb.watch(net)\n",
    "    train_ls, test_ls = [], []\n",
    "    train_iter = d2l.load_array((train_features, train_labels), batch_size)\n",
    "    # Adam optimiser; weight_decay is passed straight through to it.\n",
    "    optimizer = torch.optim.Adam(net.parameters(), lr=learning_rate,\n",
    "                                 weight_decay=weight_decay)\n",
    "    for epoch in tqdm(range(num_epochs)):\n",
    "        net.to(device)\n",
    "        for X, y in train_iter:\n",
    "            X, y = X.to(device), y.to(device)\n",
    "            optimizer.zero_grad()\n",
    "            loss = criterion(net(X), y)\n",
    "            loss.backward()\n",
    "            optimizer.step()\n",
    "        # Evaluate and checkpoint on the CPU, where the full feature tensors live.\n",
    "        net.to('cpu')\n",
    "        # Metrics only: no autograd graph is needed for these forward passes.\n",
    "        with torch.no_grad():\n",
    "            record_loss = log_rmse(net, train_features, train_labels)\n",
    "            metrics = {'loss': record_loss, 'epoch': epoch}\n",
    "            if test_labels is not None:\n",
    "                valid_loss = log_rmse(net, test_features, test_labels)\n",
    "                test_ls.append(valid_loss)\n",
    "                metrics['valid_loss'] = valid_loss\n",
    "        wandb.log(metrics)\n",
    "        train_ls.append(record_loss)\n",
    "        # Checkpoint every NUM_SAVE epochs (not at epoch 0) and after the last epoch.\n",
    "        if (epoch % NUM_SAVE == 0 and epoch != 0) or (epoch == num_epochs - 1):\n",
    "            torch.save(net.state_dict(), 'checkpoint_' + str(epoch))\n",
    "            print('save checkpoints on:', epoch, 'rmse loss value is:', record_loss)\n",
    "    # Hand the model back on the training device.\n",
    "    net.to(device)\n",
    "    wandb.finish()\n",
    "    return train_ls, test_ls"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "id": "55d5e0dd",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\u001b[34m\u001b[1mwandb\u001b[0m: Currently logged in as: \u001b[33mzengchen\u001b[0m. Use \u001b[1m`wandb login --relogin`\u001b[0m to force relogin\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "ca9fbb583f254fee803b1fe03667e12f",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "VBox(children=(Label(value='Waiting for wandb.init()...\\r'), FloatProgress(value=0.011138449999999998, max=1.0…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "Tracking run with wandb version 0.15.11"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "Run data is saved locally in <code>/Users/zengchen/编程/日常代码/动手学深度学习/动手学深度学习/多层感知机MLP/wandb/run-20231003_152235-iff6wmy6</code>"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "Syncing run <strong><a href='https://wandb.ai/zengchen/calofornia_house_predict/runs/iff6wmy6' target=\"_blank\">fiery-sunset-12</a></strong> to <a href='https://wandb.ai/zengchen/calofornia_house_predict' target=\"_blank\">Weights & Biases</a> (<a href='https://wandb.me/run' target=\"_blank\">docs</a>)<br/>"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       " View project at <a href='https://wandb.ai/zengchen/calofornia_house_predict' target=\"_blank\">https://wandb.ai/zengchen/calofornia_house_predict</a>"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       " View run at <a href='https://wandb.ai/zengchen/calofornia_house_predict/runs/iff6wmy6' target=\"_blank\">https://wandb.ai/zengchen/calofornia_house_predict/runs/iff6wmy6</a>"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "network: MLP(\n",
      "  (layer1): Linear(in_features=470, out_features=256, bias=True)\n",
      "  (layer2): Linear(in_features=256, out_features=64, bias=True)\n",
      "  (layer3): Linear(in_features=64, out_features=16, bias=True)\n",
      "  (out): Linear(in_features=16, out_features=1, bias=True)\n",
      ")\n"
     ]
    }
   ],
   "source": [
    "# Hyper-parameters for this run (k is assigned here but not read in this cell).\n",
    "k = 5\n",
    "num_epochs = 2000\n",
    "lr = 0.005\n",
    "weight_decay = 0.05\n",
    "batch_size = 256\n",
    "\n",
    "# Start a wandb run and store the hyper-parameters with it.\n",
    "# NOTE(review): no train(...) call follows this init in the notebook; the saved\n",
    "# output shows the run being closed by the next wandb.init. Confirm the intent.\n",
    "run_config = {\n",
    "    'learning_rate': lr,\n",
    "    'weight_decay': weight_decay,\n",
    "    'batch_size': batch_size,\n",
    "    'total_run': num_epochs,\n",
    "    'network': net_list,\n",
    "}\n",
    "wandb.init(project='calofornia_house_predict', config=run_config)\n",
    "print('network:', net)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "id": "92b14999",
   "metadata": {},
   "outputs": [],
   "source": [
    "# NOTE(review): in top-to-bottom order this cell runs before any call to train(),\n",
    "# so it exports predictions from freshly initialised weights; the final cell\n",
    "# rewrites submission.csv after training. Confirm whether this cell is still needed.\n",
    "net.to('cpu')\n",
    "# Inference only: no_grad avoids building an autograd graph over the whole test set.\n",
    "with torch.no_grad():\n",
    "    preds = net(test_features).numpy()\n",
    "# Reformat for the Kaggle export: one 'Sold Price' prediction per test 'Id'.\n",
    "test_data['Sold Price'] = pd.Series(preds.reshape(1, -1)[0])\n",
    "submission = pd.concat([test_data['Id'], test_data['Sold Price']], axis=1)\n",
    "submission.to_csv('submission.csv', index=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "id": "e6c5d204",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "Finishing last run (ID:iff6wmy6) before initializing another..."
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "Waiting for W&B process to finish... <strong style=\"color:green\">(success).</strong>"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "1b5abde8b8294e8d89d4fce4579e6a0b",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "VBox(children=(Label(value='0.001 MB of 0.001 MB uploaded (0.000 MB deduped)\\r'), FloatProgress(value=1.0, max…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       " View run <strong style=\"color:#cdcd00\">fiery-sunset-12</strong> at: <a href='https://wandb.ai/zengchen/calofornia_house_predict/runs/iff6wmy6' target=\"_blank\">https://wandb.ai/zengchen/calofornia_house_predict/runs/iff6wmy6</a><br/> View job at <a href='https://wandb.ai/zengchen/calofornia_house_predict/jobs/QXJ0aWZhY3RDb2xsZWN0aW9uOjEwMTU1NzM3OA==/version_details/v3' target=\"_blank\">https://wandb.ai/zengchen/calofornia_house_predict/jobs/QXJ0aWZhY3RDb2xsZWN0aW9uOjEwMTU1NzM3OA==/version_details/v3</a><br/>Synced 6 W&B file(s), 0 media file(s), 2 artifact file(s) and 0 other file(s)"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "Find logs at: <code>./wandb/run-20231003_152235-iff6wmy6/logs</code>"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "Successfully finished last run (ID:iff6wmy6). Initializing new run:<br/>"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "ecee894cbd904994879ca06739dcb42a",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "VBox(children=(Label(value='Waiting for wandb.init()...\\r'), FloatProgress(value=0.011144943055555561, max=1.0…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "Tracking run with wandb version 0.15.11"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "Run data is saved locally in <code>/Users/zengchen/编程/日常代码/动手学深度学习/动手学深度学习/多层感知机MLP/wandb/run-20231003_152243-03qeayrd</code>"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "Syncing run <strong><a href='https://wandb.ai/zengchen/kaggle_predict/runs/03qeayrd' target=\"_blank\">expert-elevator-9</a></strong> to <a href='https://wandb.ai/zengchen/kaggle_predict' target=\"_blank\">Weights & Biases</a> (<a href='https://wandb.me/run' target=\"_blank\">docs</a>)<br/>"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       " View project at <a href='https://wandb.ai/zengchen/kaggle_predict' target=\"_blank\">https://wandb.ai/zengchen/kaggle_predict</a>"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       " View run at <a href='https://wandb.ai/zengchen/kaggle_predict/runs/03qeayrd' target=\"_blank\">https://wandb.ai/zengchen/kaggle_predict/runs/03qeayrd</a>"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "network: MLP(\n",
      "  (layer1): Linear(in_features=470, out_features=256, bias=True)\n",
      "  (layer2): Linear(in_features=256, out_features=64, bias=True)\n",
      "  (layer3): Linear(in_features=64, out_features=16, bias=True)\n",
      "  (out): Linear(in_features=16, out_features=1, bias=True)\n",
      ")\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 10%|████▎                                     | 51/500 [01:04<08:51,  1.18s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "save checkpoints on: 50 rmse loss value is: 0.2818559408187866\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 20%|████████▎                                | 101/500 [02:04<07:47,  1.17s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "save checkpoints on: 100 rmse loss value is: 0.23824316263198853\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 30%|████████████▍                            | 151/500 [03:05<07:05,  1.22s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "save checkpoints on: 150 rmse loss value is: 0.2408779263496399\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 40%|████████████████▍                        | 201/500 [04:05<06:14,  1.25s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "save checkpoints on: 200 rmse loss value is: 0.2566238045692444\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 50%|████████████████████▌                    | 251/500 [05:05<04:59,  1.20s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "save checkpoints on: 250 rmse loss value is: 0.2355775386095047\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 60%|████████████████████████▋                | 301/500 [06:06<03:59,  1.20s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "save checkpoints on: 300 rmse loss value is: 0.2208111584186554\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 70%|████████████████████████████▊            | 351/500 [07:06<02:59,  1.21s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "save checkpoints on: 350 rmse loss value is: 0.21300695836544037\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 80%|████████████████████████████████▉        | 401/500 [08:06<01:57,  1.19s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "save checkpoints on: 400 rmse loss value is: 0.20981846749782562\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 90%|████████████████████████████████████▉    | 451/500 [09:06<00:59,  1.21s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "save checkpoints on: 450 rmse loss value is: 0.20845922827720642\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|█████████████████████████████████████████| 500/500 [10:05<00:00,  1.21s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "save checkpoints on: 499 rmse loss value is: 0.2053406983613968\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "Waiting for W&B process to finish... <strong style=\"color:green\">(success).</strong>"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "<style>\n",
       "    table.wandb td:nth-child(1) { padding: 0 10px; text-align: left ; width: auto;} td:nth-child(2) {text-align: left ; width: 100%}\n",
       "    .wandb-row { display: flex; flex-direction: row; flex-wrap: wrap; justify-content: flex-start; width: 100% }\n",
       "    .wandb-col { display: flex; flex-direction: column; flex-basis: 100%; flex: 1; padding: 10px; }\n",
       "    </style>\n",
       "<div class=\"wandb-row\"><div class=\"wandb-col\"><h3>Run history:</h3><br/><table class=\"wandb\"><tr><td>epoch</td><td>▁▁▁▂▂▂▂▂▂▃▃▃▃▃▄▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███</td></tr><tr><td>loss</td><td>█▃▃▂▂▂▁▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁</td></tr></table><br/></div><div class=\"wandb-col\"><h3>Run summary:</h3><br/><table class=\"wandb\"><tr><td>epoch</td><td>499</td></tr><tr><td>loss</td><td>0.20534</td></tr></table><br/></div></div>"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       " View run <strong style=\"color:#cdcd00\">expert-elevator-9</strong> at: <a href='https://wandb.ai/zengchen/kaggle_predict/runs/03qeayrd' target=\"_blank\">https://wandb.ai/zengchen/kaggle_predict/runs/03qeayrd</a><br/> View job at <a href='https://wandb.ai/zengchen/kaggle_predict/jobs/QXJ0aWZhY3RDb2xsZWN0aW9uOjEwMTU1NzY5MQ==/version_details/v4' target=\"_blank\">https://wandb.ai/zengchen/kaggle_predict/jobs/QXJ0aWZhY3RDb2xsZWN0aW9uOjEwMTU1NzY5MQ==/version_details/v4</a><br/>Synced 6 W&B file(s), 0 media file(s), 2 artifact file(s) and 0 other file(s)"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "Find logs at: <code>./wandb/run-20231003_152243-03qeayrd/logs</code>"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "# Training run: set hyper-parameters, open a wandb run, train, then export predictions.\n",
    "k, num_epochs, lr, weight_decay, batch_size = 5, 500, 0.01, 0.01, 256\n",
    "wandb.init(project=\"kaggle_predict\",\n",
    "           config={ \"learning_rate\": lr,\n",
    "                    \"weight_decay\": weight_decay,\n",
    "                    \"batch_size\": batch_size,\n",
    "                    \"total_run\": num_epochs,\n",
    "                    \"network\": net_list}\n",
    "          )\n",
    "# To continue from a saved checkpoint instead of the current weights, uncomment:\n",
    "#net.load_state_dict(torch.load('checkpoint_19676'))\n",
    "print(\"network:\",net)\n",
    "net.to(device)\n",
    "# No validation split is passed, so both validation arguments are None.\n",
    "train_ls, valid_ls = train(net, train_features, train_labels, None, None,\n",
    "                           num_epochs, lr, weight_decay, batch_size)\n",
    "net.to('cpu')\n",
    "# Inference only: no_grad avoids building an autograd graph over the whole test set.\n",
    "with torch.no_grad():\n",
    "    preds = net(test_features).numpy()\n",
    "# Reformat for the Kaggle export: one 'Sold Price' prediction per test 'Id'.\n",
    "test_data['Sold Price'] = pd.Series(preds.reshape(1, -1)[0])\n",
    "submission = pd.concat([test_data['Id'], test_data['Sold Price']], axis=1)\n",
    "submission.to_csv('submission.csv', index=False)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python(d2l)",
   "language": "python",
   "name": "d2l"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.17"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
